In [ ]:
import sys
import yaml
import tensorflow as tf
import numpy as np
import pandas as pd
import functools
from pathlib import Path
from datetime import datetime
from tqdm import tqdm_notebook as tqdm

# Plotting
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
plt.rcParams['animation.ffmpeg_path'] = str(Path.home() / "anaconda3/envs/image-processing/bin/ffmpeg")

%load_ext autoreload
%autoreload 2

import dcgan
import gan_utils
from load_data import preprocess_images
from ds_utils.generative_utils import animate_latent_transition, gen_latent_linear, gen_latent_idx
from ds_utils.plot_utils import plot_sample_imgs

In [ ]:
data_folder = Path.home() / "Documents/datasets"

In [ ]:
# load model config
with open('configs/dcgan_celeba_config.yaml', 'r') as f:
    config = yaml.safe_load(f)
HIDDEN_DIM = config['data']['z_size']
IMG_SHAPE = config['data']['input_shape']
BATCH_SIZE = config['training']['batch_size']
IMG_IS_BW = IMG_SHAPE[2] == 1
PLOT_IMG_SHAPE = IMG_SHAPE[:2] if IMG_IS_BW else IMG_SHAPE
config
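
The loaded `config` is a plain dict; an illustrative structure (placeholder values, showing only the keys this notebook actually reads) would look like:

In [ ]:
# Illustrative shape of the loaded config (placeholder values, not the real file)
config_example = {
    'data': {
        'z_size': 100,               # dimensionality of the latent vector z
        'input_shape': [64, 64, 3],  # height, width, channels of the training images
    },
    'training': {
        'batch_size': 64,
    },
}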

Data


In [ ]:
# load Fashion MNIST dataset
((X_train, y_train), (X_test, y_test)) = tf.keras.datasets.fashion_mnist.load_data()

In [ ]:
X_train = preprocess_images(X_train)
X_test = preprocess_images(X_test)

print(X_train[0].shape)
print(X_train[0].max())
print(X_train[0].min())

print(X_train.shape)

assert X_train[0].shape == tuple(config['data']['input_shape'])
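
`preprocess_images` is defined in the local `load_data` module; a minimal sketch of the preprocessing the checks above imply (adding a channel axis and rescaling pixel values), assuming the real implementation does roughly the same:

In [ ]:
# Hedged sketch of what load_data.preprocess_images is assumed to do;
# the actual implementation may differ (e.g. scaling to [-1, 1]).
def preprocess_images_sketch(imgs):
    imgs = imgs.reshape(*imgs.shape, 1).astype('float32')  # (N, 28, 28) -> (N, 28, 28, 1)
    return imgs / 255.                                      # rescale uint8 pixels to [0, 1]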

In [ ]:
train_ds = tf.data.Dataset.from_tensor_slices(X_train).take(5000)
test_ds = tf.data.Dataset.from_tensor_slices(X_test).take(256)
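
`gan.setup_dataset` (used in the training cell further down) is expected to turn these raw datasets into shuffled, batched pipelines; a hedged sketch of the equivalent `tf.data` calls:

In [ ]:
# Hedged sketch of the input pipeline gan.setup_dataset is assumed to build
batched_train_ds = (train_ds
                    .shuffle(buffer_size=1024)
                    .batch(BATCH_SIZE, drop_remainder=True)
                    .prefetch(tf.data.experimental.AUTOTUNE))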

In [ ]:
sys.path.append("../")
from tmp_load_data import load_imgs_tfdataset

In [ ]:
train_ds = load_imgs_tfdataset(data_folder/'img_align_celeba', '*.jpg', config, 500, zipped=False)
test_ds = load_imgs_tfdataset(data_folder/'img_align_celeba', '*.jpg', config, 100, zipped=False)

Model


In [ ]:
# instantiate GAN
gan = dcgan.DCGan(IMG_SHAPE, config)
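
The actual architectures live in the local `dcgan` module. Purely as an illustration of the DCGAN pattern (not the `DCGan` code itself): the generator projects the latent vector onto a small spatial grid and upsamples it with strided transposed convolutions, while the discriminator mirrors it with strided convolutions.

In [ ]:
from tensorflow.keras import layers

# Illustrative DCGAN generator (not the actual dcgan.DCGan implementation):
# project z onto an 8x8 grid, then upsample with strided transposed
# convolutions until the target 64x64 resolution is reached.
def make_generator_sketch(z_size=100, img_channels=3):
    return tf.keras.Sequential([
        layers.Dense(8 * 8 * 256, use_bias=False, input_shape=(z_size,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((8, 8, 256)),
        layers.Conv2DTranspose(128, 5, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),   # 16x16
        layers.Conv2DTranspose(64, 5, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),   # 32x32
        layers.Conv2DTranspose(img_channels, 5, strides=2, padding='same',
                               activation='tanh'),  # 64x64, pixel values in [-1, 1]
    ])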

In [ ]:
# test generator
generator_out = gan.generator.predict(np.random.randn(BATCH_SIZE, HIDDEN_DIM))
generator_out.shape

In [ ]:
# test discriminator
discriminator_out = gan.discriminator.predict(generator_out)
discriminator_out.shape

In [ ]:
# test gan
gan.gan.predict(np.random.randn(BATCH_SIZE, HIDDEN_DIM)).max()

In [ ]:
# plot random generated image
plt.imshow(gan.generator.predict([np.random.randn(1, HIDDEN_DIM)])[0]
           .reshape(PLOT_IMG_SHAPE), cmap='gray' if IMG_IS_BW else 'jet')
plt.show()

In [ ]:
gan.generator.summary()

Training


In [ ]:
# setup model directory for checkpoint and tensorboard logs
model_name = "dcgan_celeba"
model_dir = Path.home() / "Documents/models/tf_playground/gan" / model_name
model_dir.mkdir(exist_ok=True, parents=True)
export_dir = model_dir / 'export'
export_dir.mkdir(exist_ok=True)
log_dir = model_dir / "logs" / datetime.now().strftime("%Y%m%d-%H%M%S")

In [ ]:
nb_epochs = 1000
gan._train(train_ds=gan.setup_dataset(train_ds),
           validation_ds=gan.setup_dataset(test_ds),
           nb_epochs=nb_epochs,
           log_dir=log_dir,
           checkpoint_dir=export_dir,
           is_tfdataset=True)
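
`_train` is the library's own loop; conceptually, each DCGAN step alternates between updating the discriminator on real and generated batches and updating the generator through the combined model. A hedged sketch of one such step (not the actual `_train` internals), assuming the sub-models are compiled Keras models:

In [ ]:
# Hedged sketch of a single adversarial training step; the real logic
# lives in dcgan.DCGan._train.
def train_step_sketch(generator, discriminator, combined, real_imgs, z_size):
    batch_size = real_imgs.shape[0]
    noise = np.random.randn(batch_size, z_size)
    fake_imgs = generator.predict(noise)

    # 1) train the discriminator: real images -> 1, generated images -> 0
    d_loss_real = discriminator.train_on_batch(real_imgs, np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch(fake_imgs, np.zeros((batch_size, 1)))

    # 2) train the generator via the combined model (discriminator frozen),
    #    pushing it to label generated images as real
    g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))
    return d_loss_real, d_loss_fake, g_loss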

In [ ]:
# export Keras model (.h5)
gan.generator.save(str(export_dir / 'generator.h5'))
gan.discriminator.save(str(export_dir / 'discriminator.h5'))
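
The exported `.h5` files can later be reloaded for inference without re-instantiating `DCGan` (assuming the models contain no custom layers):

In [ ]:
# reload the exported generator and sample a batch of images from random noise
generator = tf.keras.models.load_model(str(export_dir / 'generator.h5'))
generator.predict(np.random.randn(16, HIDDEN_DIM)).shape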

In [ ]:
# plot generator results
plot_side = 5
plot_sample_imgs(lambda x: gan.generator.predict(np.random.randn(plot_side*plot_side, HIDDEN_DIM)), 
                 img_shape=PLOT_IMG_SHAPE,
                 plot_side=plot_side,
                 cmap='gray' if IMG_IS_BW else 'jet')

Explore Latent Space


In [ ]:
%matplotlib inline

In [ ]:
def gen_image_fun(latent_vectors):
    img = gan.generator.predict(latent_vectors)[0].reshape(PLOT_IMG_SHAPE)
    return img

In [ ]:
# quick sanity check: generate one image from a random latent vector
img = gen_image_fun(np.random.randn(1, HIDDEN_DIM))

In [ ]:
render_dir = Path.home() / 'Documents/videos/gan' / "gan_celeba"

nb_samples = 10
nb_transition_frames = 10
nb_frames = min(2000, (nb_samples-1)*nb_transition_frames)

# random list of z vectors
z_s = np.random.randn(nb_samples, HIDDEN_DIM)

animate_latent_transition(latent_vectors=z_s,
                          gen_image_fun=gen_image_fun,
                          gen_latent_fun=lambda z_s, i: gen_latent_linear(z_s, i, nb_transition_frames),
                          img_size=PLOT_IMG_SHAPE,
                          nb_frames=nb_frames,
                          render_dir=render_dir)
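
`gen_latent_linear` comes from `ds_utils.generative_utils`; the idea is plain linear interpolation between consecutive latent vectors over the transition frames. A hedged sketch of that behaviour (not the actual helper):

In [ ]:
# Hedged sketch of ds_utils.generative_utils.gen_latent_linear: linearly
# interpolate between consecutive latent vectors across the transition frames.
def gen_latent_linear_sketch(z_s, frame_idx, nb_transition_frames):
    start = frame_idx // nb_transition_frames                      # current segment
    t = (frame_idx % nb_transition_frames) / nb_transition_frames  # position within it
    return (1 - t) * z_s[start:start + 1] + t * z_s[start + 1:start + 2]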

In [ ]:
render_dir = Path.home() / 'Documents/videos/gan' / "gan_fmnist_test"

nb_transition_frames = 10

# random list of z vectors
#rand_idx = np.random.randint(len(X_train))
z_start = np.random.randn(1, HIDDEN_DIM)
vals = np.linspace(-1., 1., nb_transition_frames)

for z_idx in range(20):
    animate_latent_transition(latent_vectors=z_start,
                              gen_image_fun=gen_image_fun,
                              gen_latent_fun=lambda z_s, i: gen_latent_idx(z_s, i, z_idx, vals),
                              img_size=PLOT_IMG_SHAPE,
                              nb_frames=nb_transition_frames,
                              render_dir=render_dir)
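
`gen_latent_idx` instead sweeps a single latent dimension over a fixed range of values, which is what the loop above renders, one dimension per animation. A sketch under the same assumptions:

In [ ]:
# Hedged sketch of ds_utils.generative_utils.gen_latent_idx: keep the starting
# latent vector fixed and overwrite one coordinate with the current sweep value.
def gen_latent_idx_sketch(z_s, frame_idx, z_idx, vals):
    z = z_s.copy()
    z[0, z_idx] = vals[frame_idx]
    return z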